[core] Unify validation_step_outputs to always return list-of-lists #15470
XuesongYang wants to merge 4 commits into NVIDIA-NeMo:main
Pull request overview
This PR standardizes ModelPT.validation_step_outputs / test_step_outputs to use a consistent “list-of-lists” shape, simplifying subclass logic by removing single-vs-multi-dataloader branching and improving epoch-end dispatch/guards.
Changes:
- Updated ModelPT epoch-end logic to dispatch based on len(outputs) (single vs multi dataloader) and to skip/guard empty per-dataloader outputs.
- Normalized validation dataloader storage to List[DataLoader] in resolve_validation_dataloaders() and refactored many model validation_step/test_step implementations to always append via [dataloader_idx].
- Updated unit tests and added a regression test to ensure multi_validation_epoch_end/multi_test_epoch_end are not called when all outputs are empty.
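
The dispatch and empty-output guard described above can be sketched roughly as follows. This is a minimal illustration, not the code in modelPT.py: dispatch_epoch_end, mean_loss, and the dl{idx}_ key prefix are hypothetical names chosen for the example.

```python
from typing import Any, Callable, Dict, List

# Hypothetical signature for a per-dataloader epoch-end hook.
EpochEndFn = Callable[[List[Any], int], Dict[str, float]]


def dispatch_epoch_end(outputs: List[List[Any]], multi_epoch_end: EpochEndFn) -> Dict[str, float]:
    """Illustrative dispatch over a list-of-lists output cache (not NeMo's code)."""
    # [[]] is truthy, so `if not outputs` would not catch "every bucket is empty".
    if all(len(bucket) == 0 for bucket in outputs):
        return {}

    # Single dataloader: the N=1 case, logged without a per-dataloader prefix.
    if len(outputs) == 1:
        logs = multi_epoch_end(outputs[0], 0)
        outputs[0].clear()
        return logs

    # Multiple dataloaders: call the hook once per non-empty bucket.
    merged: Dict[str, float] = {}
    for idx, bucket in enumerate(outputs):
        if len(bucket) == 0:
            continue  # this dataloader produced nothing this epoch
        for name, value in multi_epoch_end(bucket, idx).items():
            merged[f"dl{idx}_{name}"] = value  # prefix format is an assumption
        bucket.clear()
    return merged


def mean_loss(bucket: List[float], dataloader_idx: int) -> Dict[str, float]:
    """Toy hook: average the per-step losses of one dataloader."""
    return {"val_loss": sum(bucket) / len(bucket)}


single = dispatch_epoch_end([[1.0, 3.0]], mean_loss)
multi = dispatch_epoch_end([[1.0], [], [2.0, 4.0]], mean_loss)
empty = dispatch_epoch_end([[], []], mean_loss)
```

In this sketch the hook is never invoked when every bucket is empty, which is the behavior the regression test is described as checking.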
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/core_ptl/test_ptl_stateless_timer.py | Updates test model hooks to the new list-of-lists output shape and adds an empty-epoch regression test. |
| tests/core_ptl/check_for_ranks.py | Switches test model to append outputs via validation_step_outputs[dataloader_idx] and uses multi_validation_epoch_end. |
| tests/collections/common/test_ema.py | Updates validation/test steps to append via [dataloader_idx] and uses multi_validation_epoch_end. |
| nemo/utils/model_utils.py | Wraps single validation dataloaders into a list to normalize _validation_dl shape. |
| nemo/core/classes/modelPT.py | Implements the unified list-of-lists output cache, updates epoch-end dispatch and empty-output guards. |
| nemo/collections/tts/models/magpietts_preference_optimization.py | Removes single-vs-multi branching in validation output accumulation; adjusts epoch-end logic for the new shape. |
| nemo/collections/tts/models/magpietts.py | Updates validation accumulation and epoch-end collection to use validation_step_outputs[0] consistently. |
| nemo/collections/tts/models/fastpitch.py | Updates validation accumulation and epoch-end processing to use the new output structure. |
| nemo/collections/tts/g2p/models/t5.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/tts/g2p/models/ctc.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/audio/models/audio_to_audio.py | Removes branching on dataloader count; simplifies callback setup in line with _validation_dl normalization. |
| nemo/collections/asr/models/transformer_bpe_models.py | Simplifies multi-epoch-end logic to assume per-dataloader outputs (base class iterates dataloaders). |
| nemo/collections/asr/models/ssl_models.py | Removes branching on dataloader count and fixes test_step to append to test_step_outputs. |
| nemo/collections/asr/models/sortformer_diar_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/slu_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/rnnt_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/label_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/classification_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/aed_multitask_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
     def _get_num_dataloaders(self, tag: str = 'val'):
         if tag == 'val':
    -        num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
    +        num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
         elif tag == 'test':
    -        num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1
    +        num_dataloaders = len(self._test_dl) if self._test_dl else 1
         else:
_get_num_dataloaders() now returns 1 when _validation_dl is an empty list. This changes the meaning from “number of configured dataloaders” to “at least 1”, which can cause _setup_metrics() to initialize metrics for a non-existent dataloader. Also, isinstance(self._test_dl, List) checks against typing.List; the bare alias works in isinstance checks, but a subscripted form such as List[DataLoader] raises TypeError at runtime, so a runtime type like (list, tuple) is the safer choice (likely with the same empty-list handling as for validation).
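
The isinstance point can be checked directly. The snippet below is a minimal sketch that relies only on standard typing-module behavior; the strings are stand-ins for DataLoader objects, and none of the variable names come from the PR.

```python
from typing import List

loaders = ["dl_a", "dl_b"]  # stand-ins for DataLoader objects

# The bare typing.List alias is accepted by isinstance.
bare_alias_ok = isinstance(loaders, List)

# A subscripted alias is rejected at runtime with TypeError.
try:
    isinstance(loaders, List[str])
    subscripted_error = None
except TypeError as exc:
    subscripted_error = exc

# Checking against builtin runtime types avoids the question entirely.
runtime_types_ok = isinstance(loaders, (list, tuple))
```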
c422e93 to 53da4f2
53da4f2 to 75fc50d
validation_step_outputs and test_step_outputs now always return a list of lists (one inner list per dataloader), eliminating if/else branching in every subclass that handles single-vs-multi dataloader shapes.
- validation_step_outputs property: returns [[] for _ in range(num_dl)]
- on_validation/test_epoch_end: len()==1 dispatch, all(len(o)==0 ...) empty guard, skip empty DL buckets in multi-DL loop
- Normalize _validation_dl to Optional[List[DataLoader]] in resolver
- 15 model files: self.validation_step_outputs[dataloader_idx].append()
- TTS models: RuntimeError guard for single-DL assumption
- Test models: override multi_validation_epoch_end, not on_*_epoch_end
- Bug fix: ssl_models test_step appended to wrong outputs list
- New test: empty outputs skip multi_epoch_end
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Made-with: Cursor
Same wrapping pattern as _validation_dl: wrap bare DataLoader into [DataLoader] at both single-value paths in resolve_test_dataloaders. Simplify isinstance guards in test_step_outputs and _get_num_dataloaders.
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
75fc50d to 0194224
     def _get_num_dataloaders(self, tag: str = 'val'):
         if tag == 'val':
    -        num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
    +        num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
technically I think you want
    - num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
    + num_dataloaders = len(self._validation_dl) if self._validation_dl is not None else 1
and similar everywhere else - make it clear this is a None check
The truthiness check (if self._validation_dl) and if self._validation_dl is not None are equivalent here, because after resolver normalization self._validation_dl is always either None or a non-empty List[DataLoader]. The None case is converted to [] by val_dataloader() (PTL 2.0+ doesn't accept None), but PTL then skips validation entirely for an empty list, so neither on_validation_start nor _get_num_dataloaders ever executes with [].
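
To make the difference between the two guards concrete, here is a small sketch comparing them on the three possible inputs. The helper names are hypothetical and the strings stand in for DataLoader objects; the two expressions themselves are the ones quoted in this thread.

```python
from typing import Dict, List, Optional, Tuple


def count_with_truthiness(dl: Optional[List[object]]) -> int:
    # Guard as written in the PR: None and [] both fall back to 1.
    return len(dl) if dl else 1


def count_with_none_check(dl: Optional[List[object]]) -> int:
    # Guard as suggested in review: only None falls back to 1.
    return len(dl) if dl is not None else 1


cases: Dict[str, Optional[List[object]]] = {
    "none": None,
    "empty": [],
    "two": ["dl_a", "dl_b"],
}

# Maps case name -> (truthiness result, None-check result).
comparison: Dict[str, Tuple[int, int]] = {
    name: (count_with_truthiness(dl), count_with_none_check(dl))
    for name, dl in cases.items()
}
```

The two guards agree for None and for any non-empty list; they differ only for an empty list, which the reply above argues is never reached in practice.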
Why are we setting num_dataloaders = 1 when the list is empty then?
pzelasko left a comment
Looks good in general, but before I approve - which models did you test it on, and did you test with both single and multiple validation datasets?
nithinraok left a comment
LGTM in general. How are you validating the changes? CI only runs on a single GPU; have you tested on multiple GPUs with multiple nodes?
@pzelasko @nithinraok all questions regarding test by modeling are valid. In my experience,
If we want to proceed with the move, we may need to test EncDecCTCModel, EncDecRNNTModel, FastPitch, AudioToAudio, ..., with multiple validation dataloaders in multi-GPU integration tests. Do we have such resources/tests in CI?
I'm afraid we don't have CI tests for these conditions. I think we can use up to 2 GPUs in CI currently (cc @chtruong814 to verify). It might be a good idea to add at least one test of multi-GPU, multi-validation training for some representative model, running ~10 training steps with 1 validation and parsing the CLI output to verify it's OK (Magpie is probably OK, or maybe Parakeet as it's the most popular?).
What does this PR do ?
Unify ModelPT.validation_step_outputs (and test_step_outputs) to always return a list of lists, so a single dataloader is simply the N=1 case and subclasses no longer need to branch on the output shape. Normalize both _validation_dl and _test_dl to Optional[List[DataLoader]] via their respective resolvers.

Collection: Core, ASR, TTS, Audio
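
The dataloader normalization mentioned in the description can be pictured with a small helper. This is only a sketch under assumptions: normalize_dataloaders is a hypothetical function, not resolve_validation_dataloaders or resolve_test_dataloaders, and plain objects stand in for DataLoader instances.

```python
from typing import List, Optional, Sequence, Union


def normalize_dataloaders(
    dl: Union[None, object, Sequence[object]],
) -> Optional[List[object]]:
    """Hypothetical helper: return None, or a list with one entry per dataloader."""
    if dl is None:
        return None  # nothing configured
    if isinstance(dl, (list, tuple)):
        return list(dl)  # already a collection: keep one entry per dataloader
    return [dl]  # bare dataloader: wrap so callers always see a list


loader_a, loader_b = object(), object()  # stand-ins for DataLoader objects

wrapped_single = normalize_dataloaders(loader_a)
kept_multi = normalize_dataloaders([loader_a, loader_b])
kept_none = normalize_dataloaders(None)
```

With the shape normalized in one place, downstream code can call len() on the result without checking whether it received a bare dataloader.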
Changelog
- modelPT.py: validation_step_outputs/test_step_outputs properties always return [[] for _ in range(num_dl)]; on_validation_epoch_end/on_test_epoch_end use len() == 1 instead of isinstance(..., dict) for single-vs-multi dispatch; empty-output guard updated to all(len(o) == 0 for o in ...) since [[]] is truthy; empty dataloader buckets skipped in multi-DL loop
- model_utils.py: resolve_validation_dataloaders and resolve_test_dataloaders wrap bare DataLoader into [DataLoader] at both single-value paths, normalizing _validation_dl and _test_dl to Optional[List[DataLoader]]
- modelPT.py (setup_multiple_validation_data): type annotation updated; isinstance guard simplified to truthiness check after normalization
- 15 model files: remove if/else branching in validation_step/test_step; always use self.validation_step_outputs[dataloader_idx].append(...)
- transformer_bpe_models.py: remove isinstance(outputs[0], dict) normalization loop in multi_validation_epoch_end; base class now iterates dataloaders and calls it once per DL
- audio_to_audio.py: simplify _get_num_dataloaders (both val and test) and logging callback setup after normalization
- fastpitch.py, magpietts.py, magpietts_preference_optimization.py: add RuntimeError guard for len(validation_step_outputs) != 1; add early-return on empty outputs; use self.validation_step_outputs[0] consistently
- ssl_models.py: fix EncDecMaskedTokenPredModel.test_step, which was appending to validation_step_outputs instead of test_step_outputs
- Test models (test_ema.py, check_for_ranks.py, test_ptl_stateless_timer.py): override multi_validation_epoch_end instead of on_validation_epoch_end; base class handles iteration, clearing, and per-DL prefix

Usage
No API changes for single-dataloader models: dataloader_idx=0 is the default. Subclasses should use the [dataloader_idx] indexing pattern.
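
A minimal sketch of that indexing pattern follows. TinyBase is a stand-in written for this example and is NOT NeMo's ModelPT; it only mimics the list-of-lists cache so the snippet runs without NeMo installed. TinyModel and the toy loss are likewise hypothetical; the hook names validation_step and multi_validation_epoch_end follow the ones named in this PR.

```python
from typing import Any, Dict, List


class TinyBase:
    """Stand-in for the base class (not ModelPT): one output bucket per dataloader."""

    def __init__(self, num_dataloaders: int = 1) -> None:
        self.validation_step_outputs: List[List[Dict[str, float]]] = [
            [] for _ in range(num_dataloaders)
        ]


class TinyModel(TinyBase):
    def validation_step(self, batch: List[float], batch_idx: int, dataloader_idx: int = 0) -> float:
        loss = sum(batch) / len(batch)  # toy "loss" for illustration only
        # Always index by dataloader_idx; no single-vs-multi branching needed.
        self.validation_step_outputs[dataloader_idx].append({"val_loss": loss})
        return loss

    def multi_validation_epoch_end(self, outputs: List[Dict[str, float]], dataloader_idx: int = 0) -> Dict[str, Any]:
        # Receives the outputs of ONE dataloader at a time.
        mean = sum(o["val_loss"] for o in outputs) / len(outputs)
        return {"val_loss": mean}


# Single dataloader: dataloader_idx defaults to 0.
single = TinyModel()
single.validation_step([1.0, 3.0], batch_idx=0)

# Two dataloaders: the same code path, just a different index.
multi = TinyModel(num_dataloaders=2)
multi.validation_step([2.0], batch_idx=0, dataloader_idx=1)
```

The single-dataloader model and the two-dataloader model go through the same append call, which is the simplification the PR is after.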